{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9aMFvFjcoI_v"
      },
      "outputs": [],
      "source": [
        "# Copyright 2018 The TensorFlow GAN Authors. All Rights Reserved.\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License.\n",
        "# =============================================================================="
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "35cp5a7vN9V8"
      },
      "source": [
        "# TF-GAN Tutorial\n",
        "\n",
        "Tutorial authors: joelshor@, westbrook@"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XSTQ5Flu7FMP"
      },
      "source": [
        "## Colab Prelims\n",
        "\n",
        "\n",
        "### Steps to run this notebook\n",
        "\n",
        "This notebook should be run in Colaboratory. If you are viewing this from GitHub, follow the GitHub instructions. If you are viewing this from Colaboratory, you should skip to the Colaboratory instructions.\n",
        "\n",
        "#### Steps from GitHub\n",
        "\n",
        "1. Navigate your web brower to the main Colaboratory website: https://colab.research.google.com.\n",
        "1. Click the `GitHub` tab.\n",
        "1. In the field marked `Enter a GitHub URL or search by organization or user`, put in the URL of this notebook in GitHub and click the magnifying glass icon next to it.\n",
        "1. Run the notebook in colaboratory by following the instructions below.\n",
        "\n",
        "#### Steps from Colaboratory\n",
        "\n",
        "This colab will run much faster on GPU. To use a Google Cloud\n",
        "GPU:\n",
        "\n",
        "1. Go to `Runtime \u003e Change runtime type`.\n",
        "1. Click `Hardware accelerator`.\n",
        "1. Select `GPU` and click `Save`.\n",
        "1. Click `Connect` in the upper right corner and select `Connect to hosted runtime`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "83-azWpoYsDg"
      },
      "outputs": [],
      "source": [
        "# Check that imports for the rest of the file work.\n",
        "import tensorflow.compat.v1 as tf\n",
        "!pip install tensorflow-gan\n",
        "import tensorflow_gan as tfgan\n",
        "import tensorflow_datasets as tfds\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "# Allow matplotlib images to render immediately.\n",
        "%matplotlib inline\n",
        "tf.logging.set_verbosity(tf.logging.ERROR)  # Disable noisy outputs."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b2xrX4F-OEL7"
      },
      "source": [
        "## Overview\n",
        "\n",
        "This colab will walk you through the basics of using [TF-GAN](https://github.com/tensorflow/gan) to define, train, and evaluate Generative Adversarial Networks (GANs). We describe the library's core features as well as some extra features. This colab assumes a familiarity with TensorFlow's Python API. For more on TensorFlow, please see [TensorFlow tutorials](https://www.tensorflow.org/tutorials/)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JMljl0ZwONgi"
      },
      "source": [
        "## Learning objectives\n",
        "\n",
        "In this Colab, you will learn how to:\n",
        "*   Use TF-GAN Estimators to quickly train a GAN"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pI8zy5Bz65pa"
      },
      "source": [
        "## Unconditional MNIST with GANEstimator\n",
        "\n",
        "This exercise uses TF-GAN's GANEstimator and the MNIST dataset to create a GAN for generating fake handwritten digits.\n",
        "\n",
        "### MNIST\n",
        "\n",
        "The [MNIST dataset](https://wikipedia.org/wiki/MNIST_database) contains tens of thousands of images of handwritten digits. We'll use these images to train a GAN to generate fake images of handwritten digits. This task is small enough that you'll be able to train the GAN in a matter of minutes.\n",
        "\n",
        "### GANEstimator\n",
        "\n",
        "TensorFlow's Estimator API that makes it easy to train models. TF-GAN offers `GANEstimator`, an Estimator for training GANs."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qxrYrU887Mns"
      },
      "source": [
        "### Input Pipeline\n",
        "\n",
        "We set up our input pipeline by defining an `input_fn`. in the \"Train and Eval Loop\" section below we pass this function to our GANEstimator's `train` method to initiate training.  The `input_fn`:\n",
        "\n",
        "1.  Generates the random inputs for the generator.\n",
        "2.  Uses `tensorflow_datasets` to retrieve the MNIST data.\n",
        "3.  Uses the tf.data API to format the data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Zs8kdV0w7Rtq"
      },
      "outputs": [],
      "source": [
        "import tensorflow_datasets as tfds\n",
        "import tensorflow.compat.v1 as tf\n",
        "\n",
        "def input_fn(mode, params):\n",
        "  assert 'batch_size' in params\n",
        "  assert 'noise_dims' in params\n",
        "  bs = params['batch_size']\n",
        "  nd = params['noise_dims']\n",
        "  split = 'train' if mode == tf.estimator.ModeKeys.TRAIN else 'test'\n",
        "  shuffle = (mode == tf.estimator.ModeKeys.TRAIN)\n",
        "  just_noise = (mode == tf.estimator.ModeKeys.PREDICT)\n",
        "  \n",
        "  noise_ds = (tf.data.Dataset.from_tensors(0).repeat()\n",
        "              .map(lambda _: tf.random.normal([bs, nd])))\n",
        "  \n",
        "  if just_noise:\n",
        "    return noise_ds\n",
        "\n",
        "  def _preprocess(element):\n",
        "    # Map [0, 255] to [-1, 1].\n",
        "    images = (tf.cast(element['image'], tf.float32) - 127.5) / 127.5\n",
        "    return images\n",
        "\n",
        "  images_ds = (tfds.load('mnist:3.*.*', split=split)\n",
        "               .map(_preprocess)\n",
        "               .cache()\n",
        "               .repeat())\n",
        "  if shuffle:\n",
        "    images_ds = images_ds.shuffle(\n",
        "        buffer_size=10000, reshuffle_each_iteration=True)\n",
        "  images_ds = (images_ds.batch(bs, drop_remainder=True)\n",
        "               .prefetch(tf.data.experimental.AUTOTUNE))\n",
        "\n",
        "  return tf.data.Dataset.zip((noise_ds, images_ds))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "t6aboJBr8Rig"
      },
      "source": [
        "Download the data and sanity check the inputs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zEhgLuGo8OGc"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import tensorflow_datasets as tfds\n",
        "import tensorflow_gan as tfgan\n",
        "import numpy as np\n",
        "\n",
        "params = {'batch_size': 100, 'noise_dims':64}\n",
        "with tf.Graph().as_default():\n",
        "  ds = input_fn(tf.estimator.ModeKeys.TRAIN, params)\n",
        "  numpy_imgs = next(iter(tfds.as_numpy(ds)))[1]\n",
        "img_grid = tfgan.eval.python_image_grid(numpy_imgs, grid_shape=(10, 10))\n",
        "plt.axis('off')\n",
        "plt.imshow(np.squeeze(img_grid))\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4sAetutZ9t93"
      },
      "source": [
        "### Neural Network Architecture\n",
        "\n",
        "To build our GAN we need two separate networks:\n",
        "\n",
        "*  A generator that takes input noise and outputs generated MNIST digits\n",
        "*  A discriminator that takes images and outputs a probability of being real or fake\n",
        "\n",
        "We define functions that build these networks. In the GANEstimator section below we pass the builder functions to the `GANEstimator` constructor. `GANEstimator` handles hooking the generator and discriminator together into the GAN. \n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oZ9n-jw_MG6C"
      },
      "outputs": [],
      "source": [
        "def _dense(inputs, units, l2_weight):\n",
        "  return tf.layers.dense(\n",
        "      inputs, units, None,\n",
        "      kernel_initializer=tf.keras.initializers.glorot_uniform,\n",
        "      kernel_regularizer=tf.keras.regularizers.l2(l=l2_weight),\n",
        "      bias_regularizer=tf.keras.regularizers.l2(l=l2_weight))\n",
        "\n",
        "def _batch_norm(inputs, is_training):\n",
        "  return tf.layers.batch_normalization(\n",
        "      inputs, momentum=0.999, epsilon=0.001, training=is_training)\n",
        "\n",
        "def _deconv2d(inputs, filters, kernel_size, stride, l2_weight):\n",
        "  return tf.layers.conv2d_transpose(\n",
        "      inputs, filters, [kernel_size, kernel_size], strides=[stride, stride], \n",
        "      activation=tf.nn.relu, padding='same',\n",
        "      kernel_initializer=tf.keras.initializers.glorot_uniform,\n",
        "      kernel_regularizer=tf.keras.regularizers.l2(l=l2_weight),\n",
        "      bias_regularizer=tf.keras.regularizers.l2(l=l2_weight))\n",
        "\n",
        "def _conv2d(inputs, filters, kernel_size, stride, l2_weight):\n",
        "  return tf.layers.conv2d(\n",
        "      inputs, filters, [kernel_size, kernel_size], strides=[stride, stride], \n",
        "      activation=None, padding='same',\n",
        "      kernel_initializer=tf.keras.initializers.glorot_uniform,\n",
        "      kernel_regularizer=tf.keras.regularizers.l2(l=l2_weight),\n",
        "      bias_regularizer=tf.keras.regularizers.l2(l=l2_weight))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NHkpn6ks90_R"
      },
      "outputs": [],
      "source": [
        "def unconditional_generator(noise, mode, weight_decay=2.5e-5):\n",
        "  \"\"\"Generator to produce unconditional MNIST images.\"\"\"\n",
        "  is_training = (mode == tf.estimator.ModeKeys.TRAIN)\n",
        "  \n",
        "  net = _dense(noise, 1024, weight_decay)\n",
        "  net = _batch_norm(net, is_training)\n",
        "  net = tf.nn.relu(net)\n",
        "  \n",
        "  net = _dense(net, 7 * 7 * 256, weight_decay)\n",
        "  net = _batch_norm(net, is_training)\n",
        "  net = tf.nn.relu(net)\n",
        "  \n",
        "  net = tf.reshape(net, [-1, 7, 7, 256])\n",
        "  net = _deconv2d(net, 64, 4, 2, weight_decay)\n",
        "  net = _deconv2d(net, 64, 4, 2, weight_decay)\n",
        "  # Make sure that generator output is in the same range as `inputs`\n",
        "  # ie [-1, 1].\n",
        "  net = _conv2d(net, 1, 4, 1, 0.0)\n",
        "  net = tf.tanh(net)\n",
        "\n",
        "  return net"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "w-ZqQ4_thIrP"
      },
      "outputs": [],
      "source": [
        "_leaky_relu = lambda net: tf.nn.leaky_relu(net, alpha=0.01)\n",
        "\n",
        "def unconditional_discriminator(img, unused_conditioning, mode, weight_decay=2.5e-5):\n",
        "  del unused_conditioning\n",
        "  is_training = (mode == tf.estimator.ModeKeys.TRAIN)\n",
        "  \n",
        "  net = _conv2d(img, 64, 4, 2, weight_decay)\n",
        "  net = _leaky_relu(net)\n",
        "  \n",
        "  net = _conv2d(net, 128, 4, 2, weight_decay)\n",
        "  net = _leaky_relu(net)\n",
        "  \n",
        "  net = tf.layers.flatten(net)\n",
        "  \n",
        "  net = _dense(net, 1024, weight_decay)\n",
        "  net = _batch_norm(net, is_training)\n",
        "  net = _leaky_relu(net)\n",
        "  \n",
        "  net = _dense(net, 1, weight_decay)\n",
        "\n",
        "  return net"
      ]
    },
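    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick sanity check on the two networks above, the sketch below traces the spatial size of the feature maps through each of them. It assumes the usual rule for `padding='same'`: a strided convolution yields `ceil(size / stride)` positions per dimension, and a transposed convolution yields `size * stride`. The helpers `same_conv_size` and `same_deconv_size` are illustrative names, not part of TF-GAN or TensorFlow."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "def same_conv_size(size, stride):\n",
        "  # Spatial output size of a conv layer with padding='same'.\n",
        "  return math.ceil(size / stride)\n",
        "\n",
        "def same_deconv_size(size, stride):\n",
        "  # Spatial output size of a transposed conv layer with padding='same'.\n",
        "  return size * stride\n",
        "\n",
        "# Generator: reshape to 7x7, two stride-2 deconvs, then a stride-1 conv.\n",
        "size = 7\n",
        "for stride in (2, 2):\n",
        "  size = same_deconv_size(size, stride)\n",
        "size = same_conv_size(size, 1)\n",
        "print('Generator output size: %dx%d' % (size, size))\n",
        "\n",
        "# Discriminator: two stride-2 convs on a 28x28 image, then flatten.\n",
        "# The second conv layer has 128 filters.\n",
        "size = 28\n",
        "for stride in (2, 2):\n",
        "  size = same_conv_size(size, stride)\n",
        "print('Discriminator features before flatten: %d' % (size * size * 128))"
      ]
    },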
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OhTAjxnyPS5e"
      },
      "source": [
        "### Evaluating Generative Models, and evaluating GANs\n",
        "\n",
        "\n",
        "TF-GAN provides some standard methods of evaluating generative models. In this example, we measure:\n",
        "\n",
        "*  Inception Score: called `mnist_score` below.\n",
        "*  Frechet Inception Distance\n",
        "\n",
        "We apply a pre-trained classifier to both the real data and the generated data calculate the *Inception Score*.  The Inception Score is designed to measure both quality and diversity. See [Improved Techniques for Training GANs](https://arxiv.org/abs/1606.03498) by Salimans et al for more information about the Inception Score.\n",
        "\n",
        "*Frechet Inception Distance* measures how close the generated image distribution is to the real image distribution.  See [GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium](https://arxiv.org/abs/1706.08500) by Heusel et al for more information about the Frechet Inception distance."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1jF-FW5LPTn6"
      },
      "outputs": [],
      "source": [
        "from tensorflow_gan.examples.mnist import util as eval_util\n",
        "import os\n",
        "\n",
        "def get_eval_metric_ops_fn(gan_model):\n",
        "  real_data_logits = tf.reduce_mean(gan_model.discriminator_real_outputs)\n",
        "  gen_data_logits = tf.reduce_mean(gan_model.discriminator_gen_outputs)\n",
        "  real_mnist_score = eval_util.mnist_score(gan_model.real_data)\n",
        "  generated_mnist_score = eval_util.mnist_score(gan_model.generated_data)\n",
        "  frechet_distance = eval_util.mnist_frechet_distance(\n",
        "      gan_model.real_data, gan_model.generated_data)\n",
        "  return {\n",
        "      'real_data_logits': tf.metrics.mean(real_data_logits),\n",
        "      'gen_data_logits': tf.metrics.mean(gen_data_logits),\n",
        "      'real_mnist_score': tf.metrics.mean(real_mnist_score),\n",
        "      'mnist_score': tf.metrics.mean(generated_mnist_score),\n",
        "      'frechet_distance': tf.metrics.mean(frechet_distance),\n",
        "  }"
      ]
    },
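    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the two metrics more concrete, here is a minimal NumPy sketch of the formulas behind them, assuming the classifier outputs are already available as arrays. It is not the TF-GAN implementation: `inception_score` and `frechet_distance` are illustrative helpers, and the synthetic arrays stand in for real classifier probabilities and features."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def inception_score(probs, eps=1e-12):\n",
        "  # probs: [num_images, num_classes], each row holds p(y|x).\n",
        "  marginal = probs.mean(axis=0, keepdims=True)  # p(y)\n",
        "  kl = np.sum(probs * (np.log(probs + eps) - np.log(marginal + eps)), axis=1)\n",
        "  # Score is exp of the average KL divergence between p(y|x) and p(y).\n",
        "  return float(np.exp(kl.mean()))\n",
        "\n",
        "def frechet_distance(feats_a, feats_b):\n",
        "  # Fit a Gaussian to each feature set of shape [num_images, feature_dim].\n",
        "  mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)\n",
        "  cov_a = np.cov(feats_a, rowvar=False)\n",
        "  cov_b = np.cov(feats_b, rowvar=False)\n",
        "  # Tr(sqrt(cov_a cov_b)) is the sum of the square roots of the eigenvalues\n",
        "  # of the product, which are real and non-negative up to rounding error.\n",
        "  eigvals = np.linalg.eigvals(cov_a @ cov_b)\n",
        "  trace_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()\n",
        "  mean_term = np.sum((mu_a - mu_b) ** 2)\n",
        "  return float(mean_term + np.trace(cov_a) + np.trace(cov_b) - 2.0 * trace_sqrt)\n",
        "\n",
        "# Confident predictions spread evenly over 10 classes give the highest\n",
        "# possible score, which equals the number of classes.\n",
        "one_hot = np.eye(10)[np.arange(100) % 10]\n",
        "print('Inception Score of ideal predictions: %.2f' % inception_score(one_hot))\n",
        "\n",
        "# Synthetic 8-dimensional features; shifting every feature by 1 should add\n",
        "# roughly 8 (the squared length of the shift) to the distance.\n",
        "rng = np.random.RandomState(0)\n",
        "feats = rng.normal(size=(500, 8))\n",
        "print('Distance to itself: %.4f' % frechet_distance(feats, feats))\n",
        "print('Distance to a shifted copy: %.4f' % frechet_distance(feats, feats + 1.0))"
      ]
    },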
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kxF2-gWHHaej"
      },
      "source": [
        "### GANEstimator\n",
        "\n",
        "The `GANEstimator` assembles and manages the pieces of the whole GAN model. The `GANEstimator` constructor takes the following compoonents for both the generator and discriminator:\n",
        "\n",
        "*  Network builder functions: we defined these in the \"Neural Network Architecture\" section above.\n",
        "*  Loss functions: here we use the wasserstein loss for both.\n",
        "*  Optimizers: here we use `tf.train.AdamOptimizer` for both generator and discriminator training."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OBd8Vg7lHit8"
      },
      "outputs": [],
      "source": [
        "train_batch_size = 32 #@param\n",
        "noise_dimensions = 64 #@param\n",
        "generator_lr = 0.001 #@param\n",
        "discriminator_lr = 0.0002 #@param\n",
        "\n",
        "def gen_opt():\n",
        "  gstep = tf.train.get_or_create_global_step()\n",
        "  base_lr = generator_lr\n",
        "  # Halve the learning rate at 1000 steps.\n",
        "  lr = tf.cond(gstep \u003c 1000, lambda: base_lr, lambda: base_lr / 2.0)\n",
        "  return tf.train.AdamOptimizer(lr, 0.5)\n",
        "\n",
        "gan_estimator = tfgan.estimator.GANEstimator(\n",
        "    generator_fn=unconditional_generator,\n",
        "    discriminator_fn=unconditional_discriminator,\n",
        "    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,\n",
        "    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,\n",
        "    params={'batch_size': train_batch_size, 'noise_dims': noise_dimensions},\n",
        "    generator_optimizer=gen_opt,\n",
        "    discriminator_optimizer=tf.train.AdamOptimizer(discriminator_lr, 0.5),\n",
        "    get_eval_metric_ops_fn=get_eval_metric_ops_fn)"
      ]
    },
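    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For intuition, the Wasserstein objective used above can be written in a few lines of NumPy. This is a simplified sketch of the loss formulas only and leaves out the extra options that the library functions accept; `sketch_generator_loss`, `sketch_discriminator_loss`, and the discriminator outputs below are made up for illustration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def sketch_generator_loss(disc_gen):\n",
        "  # The generator tries to raise the discriminator's output on generated samples.\n",
        "  return -np.mean(disc_gen)\n",
        "\n",
        "def sketch_discriminator_loss(disc_real, disc_gen):\n",
        "  # The discriminator tries to score real samples above generated ones.\n",
        "  return np.mean(disc_gen) - np.mean(disc_real)\n",
        "\n",
        "disc_real = np.array([1.5, 2.0, 0.5])   # Hypothetical outputs on real images.\n",
        "disc_gen = np.array([-1.0, 0.0, -0.5])  # Hypothetical outputs on generated images.\n",
        "print('Generator loss: %.4f' % sketch_generator_loss(disc_gen))\n",
        "print('Discriminator loss: %.4f' % sketch_discriminator_loss(disc_real, disc_gen))"
      ]
    },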
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n1uldXfUfstT"
      },
      "source": [
        "### Train and eval loop\n",
        "\n",
        "The `GANEstimator`'s `train()` method initiates GAN training, including the alternating generator and discriminator training phases.\n",
        "\n",
        "The loop in the code below calls `train()` repeatedly in order to periodically display generator output and evaluation results. But note that the code below does not manage the alternation between discriminator and generator: that's all handled automatically by `train()`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AH6gcvcwHvSn"
      },
      "outputs": [],
      "source": [
        "# Disable noisy output.\n",
        "tf.autograph.set_verbosity(0, False)\n",
        "\n",
        "import time\n",
        "steps_per_eval = 500 #@param\n",
        "max_train_steps = 5000 #@param\n",
        "batches_for_eval_metrics = 100 #@param\n",
        "\n",
        "# Used to track metrics.\n",
        "steps = []\n",
        "real_logits, fake_logits = [], []\n",
        "real_mnist_scores, mnist_scores, frechet_distances = [], [], []\n",
        "\n",
        "cur_step = 0\n",
        "start_time = time.time()\n",
        "while cur_step \u003c max_train_steps:\n",
        "  next_step = min(cur_step + steps_per_eval, max_train_steps)\n",
        "\n",
        "  start = time.time()\n",
        "  gan_estimator.train(input_fn, max_steps=next_step)\n",
        "  steps_taken = next_step - cur_step\n",
        "  time_taken = time.time() - start\n",
        "  print('Time since start: %.2f min' % ((time.time() - start_time) / 60.0))\n",
        "  print('Trained from step %i to %i in %.2f steps / sec' % (\n",
        "      cur_step, next_step, steps_taken / time_taken))\n",
        "  cur_step = next_step\n",
        "  \n",
        "  # Calculate some metrics.\n",
        "  metrics = gan_estimator.evaluate(input_fn, steps=batches_for_eval_metrics)\n",
        "  steps.append(cur_step)\n",
        "  real_logits.append(metrics['real_data_logits'])\n",
        "  fake_logits.append(metrics['gen_data_logits'])\n",
        "  real_mnist_scores.append(metrics['real_mnist_score'])\n",
        "  mnist_scores.append(metrics['mnist_score'])\n",
        "  frechet_distances.append(metrics['frechet_distance'])\n",
        "  print('Average discriminator output on Real: %.2f  Fake: %.2f' % (\n",
        "      real_logits[-1], fake_logits[-1]))\n",
        "  print('Inception Score: %.2f / %.2f  Frechet Distance: %.2f' % (\n",
        "      mnist_scores[-1], real_mnist_scores[-1], frechet_distances[-1]))\n",
        "  \n",
        "  # Vizualize some images.\n",
        "  iterator = gan_estimator.predict(\n",
        "      input_fn, hooks=[tf.train.StopAtStepHook(num_steps=21)])\n",
        "  try:\n",
        "    imgs = np.array([next(iterator) for _ in range(20)])\n",
        "  except StopIteration:\n",
        "    pass\n",
        "  tiled = tfgan.eval.python_image_grid(imgs, grid_shape=(2, 10))\n",
        "  plt.axis('off')\n",
        "  plt.imshow(np.squeeze(tiled))\n",
        "  plt.show()\n",
        "  \n",
        "  \n",
        "# Plot the metrics vs step.\n",
        "plt.title('MNIST Frechet distance per step')\n",
        "plt.plot(steps, frechet_distances)\n",
        "plt.figure()\n",
        "plt.title('MNIST Score per step')\n",
        "plt.plot(steps, mnist_scores)\n",
        "plt.plot(steps, real_mnist_scores)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uy1dsvWuwJeS"
      },
      "source": [
        "### Next steps\n",
        "\n",
        "Try [this colab notebook](https://github.com/tensorflow/gan) to train a GAN on Google's Cloud TPU use TF-GAN.\n",
        "\n",
        "\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/brain/python/client:colab_notebook",
        "kind": "private"
      },
      "name": "TF-GAN Tutorial",
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "/piper/depot/tensorflow_gan/examples/colab_notebooks/tfgan_tutorial.ipynb",
          "timestamp": 1625147882111
        },
        {
          "file_id": "/piper/depot/tensorflow_gan/examples/colab_notebooks/tfgan_tutorial.ipynb",
          "timestamp": 1625002068968
        },
        {
          "file_id": "/piper/depot/tensorflow_gan/examples/colab_notebooks/tfgan_tutorial.ipynb",
          "timestamp": 1571383618849
        },
        {
          "file_id": "/piper/depot/tensorflow_gan/examples/colab_notebooks/tfgan_tutorial.ipynb",
          "timestamp": 1569547390651
        },
        {
          "file_id": "/piper/depot/tensorflow_gan/examples/colab_notebooks/tfgan_tutorial.ipynb",
          "timestamp": 1559972047311
        },
        {
          "file_id": "/piper/depot/tensorflow_gan/examples/colab_notebooks/tfgan_tutorial.ipynb",
          "timestamp": 1559900570952
        },
        {
          "file_id": "/piper/depot/tensorflow_gan/examples/colab_notebooks/tfgan_tutorial.ipynb",
          "timestamp": 1559897391264
        },
        {
          "file_id": "/piper/depot/tensorflow_gan/examples/colab_notebooks/tfgan_tutorial.ipynb",
          "timestamp": 1559752800451
        },
        {
          "file_id": "/piper/depot/tensorflow_gan/examples/colab_notebooks/tfgan_tutorial.ipynb",
          "timestamp": 1559719883868
        },
        {
          "file_id": "/piper/depot/tensorflow_gan/examples/colab_notebooks/tfgan_tutorial.ipynb",
          "timestamp": 1559717312855
        },
        {
          "file_id": "/piper/depot/tensorflow_gan/examples/colab_notebooks/tfgan_tutorial.ipynb",
          "timestamp": 1559641947244
        },
        {
          "file_id": "14r58gghjLTBBQVoSFOBdPsvj-I1G6nbd",
          "timestamp": 1549819781952
        },
        {
          "file_id": "0Bz8X96EaC_2-ZW9odlhSOEFXdWs",
          "timestamp": 1493398103910
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
